{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zRN0MvYFIVmT"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "kNIU45vmIl80"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9PcbvP7Lz1Pn"
      },
      "source": [
        "# Gemma Inference on TPUs\n",
        "This notebook demonstrates how to leverage Google Colab's TPUs for inference with [Gemma](https://ai.google.dev/gemma) , an open-weights Large Language Model (LLM), using the [Flax](https://github.com/google/flax).\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_1]Inference_on_TPU.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "09OSQCd5ebzP"
      },
      "source": [
        "### Connect to a TPU\n",
        "- To connect to a TPU v2, click on the button Connect TPU in the top right-hand corner of the screen."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6wOtGTbWfKX1"
      },
      "source": [
        "Now you can see the TPU devices you have available:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TCXhNGCJexoK"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),\n",
              " TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),\n",
              " TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),\n",
              " TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),\n",
              " TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),\n",
              " TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),\n",
              " TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),\n",
              " TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]"
            ]
          },
          "execution_count": 1,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "import jax\n",
        "\n",
        "jax.devices()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LtzOe_3XY9R5"
      },
      "source": [
        "## Installation"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b_42SyQifbJ2"
      },
      "source": [
        "- To install Gemma you need to use Python 3.10 or higher.\n",
        "- Google Colab typically offers Python 3.6 or later versions as the default runtime environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iq2ebV_6YNiU"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting git+https://github.com/google-deepmind/gemma.git\n",
            "  Cloning https://github.com/google-deepmind/gemma.git to /tmp/pip-req-build-vdzv6aiz\n",
            "  Running command git clone --filter=blob:none --quiet https://github.com/google-deepmind/gemma.git /tmp/pip-req-build-vdzv6aiz\n",
            "  Resolved https://github.com/google-deepmind/gemma.git to commit a24194737dcb54b7392091e9ba772aea1cb68ffb\n",
            "  Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
            "  Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
            "  Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
            "Requirement already satisfied: absl-py<3.0.0,>=2.1.0 in /usr/local/lib/python3.10/dist-packages (from gemma==1.0.0) (2.1.0)\n",
            "Requirement already satisfied: flax>=0.8 in /usr/local/lib/python3.10/dist-packages (from gemma==1.0.0) (0.8.4)\n",
            "Requirement already satisfied: sentencepiece<0.2.0,>=0.1.99 in /usr/local/lib/python3.10/dist-packages (from gemma==1.0.0) (0.1.99)\n",
            "Requirement already satisfied: numpy>=1.22 in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (1.25.2)\n",
            "Requirement already satisfied: jax>=0.4.19 in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (0.4.26)\n",
            "Requirement already satisfied: msgpack in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (1.0.8)\n",
            "Requirement already satisfied: optax in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (0.1.9)\n",
            "Requirement already satisfied: orbax-checkpoint in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (0.4.4)\n",
            "Requirement already satisfied: tensorstore in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (0.1.45)\n",
            "Requirement already satisfied: rich>=11.1 in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (13.7.1)\n",
            "Requirement already satisfied: typing-extensions>=4.2 in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (4.12.2)\n",
            "Requirement already satisfied: PyYAML>=5.4.1 in /usr/local/lib/python3.10/dist-packages (from flax>=0.8->gemma==1.0.0) (6.0.1)\n",
            "Requirement already satisfied: ml-dtypes>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.19->flax>=0.8->gemma==1.0.0) (0.2.0)\n",
            "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.19->flax>=0.8->gemma==1.0.0) (3.3.0)\n",
            "Requirement already satisfied: scipy>=1.9 in /usr/local/lib/python3.10/dist-packages (from jax>=0.4.19->flax>=0.8->gemma==1.0.0) (1.11.4)\n",
            "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1->flax>=0.8->gemma==1.0.0) (3.0.0)\n",
            "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1->flax>=0.8->gemma==1.0.0) (2.18.0)\n",
            "Requirement already satisfied: chex>=0.1.7 in /usr/local/lib/python3.10/dist-packages (from optax->flax>=0.8->gemma==1.0.0) (0.1.86)\n",
            "Requirement already satisfied: jaxlib>=0.1.37 in /usr/local/lib/python3.10/dist-packages (from optax->flax>=0.8->gemma==1.0.0) (0.4.26)\n",
            "Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint->flax>=0.8->gemma==1.0.0) (1.7.0)\n",
            "Requirement already satisfied: nest_asyncio in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint->flax>=0.8->gemma==1.0.0) (1.6.0)\n",
            "Requirement already satisfied: protobuf in /usr/local/lib/python3.10/dist-packages (from orbax-checkpoint->flax>=0.8->gemma==1.0.0) (3.20.3)\n",
            "Requirement already satisfied: toolz>=0.9.0 in /usr/local/lib/python3.10/dist-packages (from chex>=0.1.7->optax->flax>=0.8->gemma==1.0.0) (0.12.1)\n",
            "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1->flax>=0.8->gemma==1.0.0) (0.1.2)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.8->gemma==1.0.0) (2024.6.1)\n",
            "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.8->gemma==1.0.0) (6.4.0)\n",
            "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->orbax-checkpoint->flax>=0.8->gemma==1.0.0) (3.19.2)\n",
            "Requirement already satisfied: kaggle in /usr/local/lib/python3.10/dist-packages (1.6.14)\n",
            "Requirement already satisfied: six>=1.10 in /usr/local/lib/python3.10/dist-packages (from kaggle) (1.16.0)\n",
            "Requirement already satisfied: certifi>=2023.7.22 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2024.7.4)\n",
            "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.9.0.post0)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.31.0)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from kaggle) (4.66.4)\n",
            "Requirement already satisfied: python-slugify in /usr/local/lib/python3.10/dist-packages (from kaggle) (8.0.4)\n",
            "Requirement already satisfied: urllib3 in /usr/local/lib/python3.10/dist-packages (from kaggle) (2.0.7)\n",
            "Requirement already satisfied: bleach in /usr/local/lib/python3.10/dist-packages (from kaggle) (6.1.0)\n",
            "Requirement already satisfied: webencodings in /usr/local/lib/python3.10/dist-packages (from bleach->kaggle) (0.5.1)\n",
            "Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.10/dist-packages (from python-slugify->kaggle) (1.3)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.3.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->kaggle) (3.7)\n"
          ]
        }
      ],
      "source": [
        "! pip install git+https://github.com/google-deepmind/gemma.git\n",
        "! pip install --user kaggle"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QOzN-gxIYSB4"
      },
      "source": [
        "## Downloading the Gemma Checkpoint\n",
        "\n",
        "Before using [Google Gemma](https://ai.google.dev/gemma) for the first time, you must request access to the model through Kaggle by setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup), or completing the following steps:\n",
        "\n",
        "1. Log in to [Kaggle](https://www.kaggle.com) or create a new Kaggle account if you don't already have one.\n",
        "1. Go to the [Gemma model card](https://www.kaggle.com/models/google/paligemma/), and click **Request Access**.\n",
        "1. Complete the consent form and accept the terms and conditions.\n",
        "\n",
        "To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
        "\n",
        "Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under desired name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n",
        "\n",
        "Set environment variables for Kaggle API credentials."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "likVQiEEYS5X"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "O-sxcasvESaM"
      },
      "outputs": [],
      "source": [
        "import kagglehub\n",
        "\n",
        "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n",
        "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
        "\n",
        "ckpt_path = os.path.join(weights_dir, VARIANT)\n",
        "vocab_path = os.path.join(weights_dir, 'tokenizer.model')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-jpTUa1YESaM"
      },
      "outputs": [],
      "source": [
        "from gemma import params as params_lib\n",
        "from gemma import sampler as sampler_lib\n",
        "from gemma import transformer as transformer_lib\n",
        "import sentencepiece as spm"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4fDQsC87ESaN"
      },
      "source": [
        "## Start Generating with Your Model\n",
        "\n",
        "Load and prepare your LLM's checkpoint for use with Flax."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "57nMYQ4HESaN"
      },
      "outputs": [],
      "source": [
        "# Load parameters\n",
        "params = params_lib.load_and_format_params(ckpt_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NWJ3UvHXESaN"
      },
      "source": [
        "Load your tokenizer, which you'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "khXrjEF0ESaN"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "True"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "vocab = spm.SentencePieceProcessor()\n",
        "vocab.Load(vocab_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tCRtZMg0ESaN"
      },
      "source": [
        "Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "7InOzQtcESaN"
      },
      "outputs": [],
      "source": [
        "transformer_config=transformer_lib.TransformerConfig.from_params(\n",
        "    params,\n",
        "    cache_size=1024  # Number of time steps in the transformer's cache\n",
        ")\n",
        "transformer = transformer_lib.Transformer(transformer_config)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KaU-X3_jESaN"
      },
      "source": [
        "Finally, build a sampler on top of your model and your tokenizer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "bdstASGrESaN"
      },
      "outputs": [],
      "source": [
        "# Create a sampler with the right param shapes.\n",
        "sampler = sampler_lib.Sampler(\n",
        "    transformer=transformer,\n",
        "    vocab=vocab,\n",
        "    params=params['transformer'],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C1fLns-_ESaN"
      },
      "source": [
        "You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
      ]
    },
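    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to keep the batch size consistent is to split your prompts into fixed-size batches and pad the last one. The cell below is a minimal sketch, not part of the Gemma library: `fixed_size_batches` and its `filler` argument are hypothetical names introduced for illustration, and it assumes that an empty string is an acceptable filler prompt for your sampler. It only fixes the batch dimension, so other shape changes (for example a different `total_generation_steps`) may still trigger recompilation.\n",
        "\n",
        "To use it, pass each `batch` to the sampler as `input_strings` and keep only the first `n_real` outputs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Hypothetical helper (not part of the Gemma library): split prompts into\n",
        "# batches that all contain exactly batch_size entries.\n",
        "def fixed_size_batches(prompts, batch_size, filler=''):\n",
        "    if batch_size < 1:\n",
        "        raise ValueError('batch_size must be at least 1')\n",
        "    for start in range(0, len(prompts), batch_size):\n",
        "        chunk = list(prompts[start:start + batch_size])\n",
        "        n_real = len(chunk)  # number of genuine prompts in this batch\n",
        "        # Assumption: an empty string is an acceptable filler prompt.\n",
        "        chunk += [filler] * (batch_size - n_real)\n",
        "        yield chunk, n_real\n",
        "\n",
        "\n",
        "example_prompts = ['Explain tides.', 'What is a comet?', 'Define albedo.']\n",
        "batches = list(fixed_size_batches(example_prompts, batch_size=2))\n",
        "for batch, n_real in batches:\n",
        "    print(len(batch), n_real)"
      ]
    },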
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "qA0BhNQvESaN"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prompt:\n",
            "\n",
            " Explain the phenomenon of a solar eclipse.\n",
            " Answer:\n",
            "\n",
            "\n",
            "A solar eclipse occurs when the Moon passes between the Sun and Earth, casting a shadow on Earth. This phenomenon is caused by the relative positions of the Moon, Sun, and Earth.\n",
            "\n",
            "**Here's a step-by-step explanation of how a solar eclipse occurs:**\n",
            "\n",
            "1. **New Moon:** The Moon is positioned between the Sun and Earth, and the Sun's rays are not directly visible from Earth.\n",
            "2. **Waxing Crescent Phase:** As the Moon orbits the Sun, it gradually moves from the new moon phase to the waxing crescent phase. This means that the illuminated portion of the Moon is gradually increasing.\n",
            "3. **First Quarter Phase:** When the Moon is at the first quarter phase, half of its face is illuminated.\n",
            "4. **Waxing Gibbous Phase:** As the Moon continues to orbit the Sun, it moves further away from the Sun, and the illuminated portion of the Moon gradually increases to the waxing gibbous phase. This means that more and more of the Moon is illuminated.\n",
            "5. **Full Moon:** When the Moon is at the full moon phase, the entire face of the Moon is illuminated.\n",
            "6. **Waning Gibbous Phase:** As the Moon moves away from the Sun, it gradually moves back into the waning gibbous phase. This means that the illuminated portion of the Moon is gradually decreasing.\n",
            "7. **Third Quarter Phase:** When the Moon is at the third quarter phase, half\n",
            "\n"
          ]
        }
      ],
      "source": [
        "input_batch = [\n",
        "    \"\\n Explain the phenomenon of a solar eclipse.\",\n",
        "  ]\n",
        "\n",
        "out_data = sampler(\n",
        "    input_strings=input_batch,\n",
        "    total_generation_steps=300,\n",
        "  )\n",
        "\n",
        "for input_string, out_string in zip(input_batch, out_data.text):\n",
        "  print(f\"Prompt:\\n{input_string}\\n Answer:\\n{out_string}\")\n",
        "  print()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "TPU",
    "colab": {
      "name": "[Gemma_1]Inference_on_TPU.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
